"""
PCBA dataset loader.
"""
import os
import deepchem as dc
from deepchem.molnet.load_function.molnet_loader import TransformerGenerator, _MolnetLoader
from deepchem.data import Dataset
from typing import List, Optional, Tuple, Union

# Download location template: ``%s`` is filled in with the assay file name.
PCBA_URL = "https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/%s"

# Label columns of the PCBA assay CSV.  Each task name is "PCBA-" followed by
# a PubChem BioAssay ID.  The order of the IDs below is the order handed to
# ``CSVLoader`` as ``tasks``, so keep it stable.
PCBA_TASKS = [
    'PCBA-%d' % aid for aid in (
        1030, 1379, 1452, 1454, 1457,
        1458, 1460, 1461, 1468, 1469,
        1471, 1479, 1631, 1634, 1688,
        1721, 2100, 2101, 2147, 2242,
        2326, 2451, 2517, 2528, 2546,
        2549, 2551, 2662, 2675, 2676, 411,
        463254, 485281, 485290, 485294, 485297,
        485313, 485314, 485341, 485349, 485353,
        485360, 485364, 485367, 492947, 493208,
        504327, 504332, 504333, 504339, 504444,
        504466, 504467, 504706, 504842, 504845,
        504847, 504891, 540276, 540317, 588342,
        588453, 588456, 588579, 588590, 588591,
        588795, 588855, 602179, 602233, 602310,
        602313, 602332, 624170, 624171, 624173,
        624202, 624246, 624287, 624288, 624291,
        624296, 624297, 624417, 651635, 651644,
        651768, 651965, 652025, 652104, 652105,
        652106, 686970, 686978, 686979, 720504,
        720532, 720542, 720551, 720553, 720579,
        720580, 720707, 720708, 720709, 720711,
        743255, 743266, 875, 881, 883,
        884, 885, 887, 891, 899, 902,
        903, 904, 912, 914, 915, 924,
        925, 926, 927, 938, 995,
    )
]


class _PCBALoader(_MolnetLoader):
    """MolNet loader for a PubChem BioAssay (PCBA) CSV file.

    Unlike a fixed-file loader, this one is parameterized by the name of the
    assay file to use. The file is looked up in ``data_dir`` and, if it is
    missing, downloaded from ``PCBA_URL``.
    """

    def __init__(self, assay_file_name: str,
                 featurizer: Union[dc.feat.Featurizer, str],
                 splitter: Union[dc.splits.Splitter, str, None],
                 transformer_generators: List[Union[TransformerGenerator, str]],
                 tasks: List[str], data_dir: Optional[str],
                 save_dir: Optional[str], **kwargs):
        """Create a loader for one PCBA assay file.

        Parameters
        ----------
        assay_file_name: str
            File name of the assay CSV (e.g. ``'pcba.csv.gz'``). It is used
            both as the local file name inside ``data_dir`` and to build the
            download URL from ``PCBA_URL``.
        featurizer, splitter, transformer_generators, tasks, data_dir, save_dir
            Passed unchanged to ``_MolnetLoader.__init__``.
        **kwargs
            Extra loader options, passed on to ``_MolnetLoader.__init__``.
        """
        # Forward **kwargs: load_pcba passes its own **kwargs through here,
        # and dropping them would silently ignore options the caller supplied.
        super().__init__(featurizer, splitter, transformer_generators, tasks,
                         data_dir, save_dir, **kwargs)
        self.assay_file_name = assay_file_name

    def create_dataset(self) -> Dataset:
        """Read (downloading first if needed) and featurize the assay CSV.

        Returns
        -------
        Dataset
            Built by ``dc.data.CSVLoader``. The ``smiles`` column is used as
            the featurizer input and ``self.tasks`` as the label columns.
        """
        dataset_file = os.path.join(self.data_dir, self.assay_file_name)
        if not os.path.exists(dataset_file):
            # Only hit the network when the raw file is not already cached.
            dc.utils.data_utils.download_url(
                url=PCBA_URL % self.assay_file_name, dest_dir=self.data_dir)
        loader = dc.data.CSVLoader(tasks=self.tasks,
                                   feature_field="smiles",
                                   featurizer=self.featurizer)
        return loader.create_dataset(dataset_file)


def load_pcba(
    featurizer: Union[dc.feat.Featurizer, str] = 'ECFP',
    splitter: Union[dc.splits.Splitter, str, None] = 'scaffold',
    transformers: List[Union[TransformerGenerator, str]] = ['balancing'],
    reload: bool = True,
    data_dir: Optional[str] = None,
    save_dir: Optional[str] = None,
    **kwargs
) -> Tuple[List[str], Tuple[Dataset, ...], List[dc.trans.Transformer]]:
    """Load PCBA dataset

    PubChem BioAssay (PCBA) is a database consisting of biological activities of
    small molecules generated by high-throughput screening. We use a subset of
    PCBA, containing 128 bioassays measured over 400 thousand compounds,
    used by previous work to benchmark machine learning methods.

    Random splitting is recommended for this dataset. Note that the default
    ``splitter`` is ``'scaffold'``; pass ``splitter='random'`` to follow this
    recommendation.

    The raw data csv file contains columns below:

    - "mol_id" - PubChem CID of the compound
    - "smiles" - SMILES representation of the molecular structure
    - "PCBA-XXX" - Measured results (Active/Inactive) for bioassays:
        search for the assay ID at
        https://pubchem.ncbi.nlm.nih.gov/search/#collection=bioassays
        for details

    Parameters
    ----------
    featurizer: Featurizer or str
        the featurizer to use for processing the data.  Alternatively you can pass
        one of the names from dc.molnet.featurizers as a shortcut.
        Defaults to 'ECFP'.
    splitter: Splitter or str
        the splitter to use for splitting the data into training, validation, and
        test sets.  Alternatively you can pass one of the names from
        dc.molnet.splitters as a shortcut.  If this is None, all the data
        will be included in a single dataset.  Defaults to 'scaffold'.
    transformers: list of TransformerGenerators or strings
        the Transformers to apply to the data.  Each one is specified by a
        TransformerGenerator or, as a shortcut, one of the names from
        dc.molnet.transformers.  Defaults to ['balancing'].
    reload: bool
        if True, the first call for a particular featurizer and splitter will cache
        the datasets to disk, and subsequent calls will reload the cached datasets.
    data_dir: str
        a directory to save the raw data in
    save_dir: str
        a directory to save the dataset in
    **kwargs
        additional options, passed through to the underlying PCBA loader.

    Returns
    -------
    tasks, datasets, transformers
        the list of task names, a tuple of datasets (the splits produced by
        ``splitter``, or a single dataset when ``splitter`` is None), and the
        list of transformers that were applied.

    References
    ----------
    .. [1] Wang, Yanli, et al. "PubChem's BioAssay database."
        Nucleic acids research 40.D1 (2011): D400-D412.
    """
    # Keyword arguments make it explicit that 'pcba.csv.gz' is the assay file
    # (the 128-task subset described above) rather than a featurizer or task.
    loader = _PCBALoader(assay_file_name='pcba.csv.gz',
                         featurizer=featurizer,
                         splitter=splitter,
                         transformer_generators=transformers,
                         tasks=PCBA_TASKS,
                         data_dir=data_dir,
                         save_dir=save_dir,
                         **kwargs)
    return loader.load_dataset('pcba', reload)


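# NOTE: The two loaders below are disabled. They call a `load_pcba_dataset`
# helper and use the old `split=` argument; neither is defined in this module.
# Re-enabling them would mean building on `_PCBALoader` with
# 'pcba_146.csv.gz' / 'pcba_2475.csv.gz' and their own task lists, which are
# also not defined here.
# def load_pcba_146(featurizer='ECFP',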
#                   split='random',
#                   reload=True,
#                   data_dir=None,
#                   save_dir=None,
#                   **kwargs):
#   return load_pcba_dataset(
#       featurizer=featurizer,
#       split=split,
#       reload=reload,
#       assay_file_name="pcba_146.csv.gz",
#       data_dir=data_dir,
#       save_dir=save_dir,
#       **kwargs)

# def load_pcba_2475(featurizer='ECFP',
#                    split='random',
#                    reload=True,
#                    data_dir=None,
#                    save_dir=None,
#                    **kwargs):
#   return load_pcba_dataset(
#       featurizer=featurizer,
#       split=split,
#       reload=reload,
#       assay_file_name="pcba_2475.csv.gz",
#       data_dir=data_dir,
#       save_dir=save_dir,
#       **kwargs)
